from dataclasses import dataclass

import pytest
from inline_snapshot import snapshot
from pydantic import BaseModel, ValidationInfo, field_validator

from pydantic_ai import (
    Agent,
    ModelMessage,
    ModelResponse,
    NativeOutput,
    PromptedOutput,
    RunContext,
    TextPart,
    ToolCallPart,
    ToolOutput,
)
from pydantic_ai._output import OutputSpec
from pydantic_ai.models.function import AgentInfo, FunctionModel


# Test model: `x` is shifted by whatever integer validation context the agent
# passes to pydantic (a falsy/missing context leaves `x` unchanged).
# NOTE: documented with comments rather than a class docstring, because a model
# docstring would become the JSON-schema `description` sent to the model.
class Value(BaseModel):
    x: int

    @field_validator('x')
    @classmethod
    def increment_value(cls, value: int, info: ValidationInfo):
        """Return ``value`` plus the validation context, treating a falsy context (e.g. ``None``) as 0."""
        offset = info.context if info.context else 0
        return value + offset


@dataclass()
class Deps:
    """Run dependencies for the agents in this module.

    Every agent here derives its ``validation_context`` from ``increment``.
    """

    increment: int


@pytest.mark.parametrize(
    'output_type',
    [
        Value,
        ToolOutput(Value),
        NativeOutput(Value),
        PromptedOutput(Value),
    ],
    ids=[
        'Value',
        'ToolOutput(Value)',
        'NativeOutput(Value)',
        'PromptedOutput(Value)',
    ],
)
def test_agent_output_with_validation_context(output_type: OutputSpec[Value]):
    """Test that the output is validated using the validation context.

    The mocked model always answers with ``x=0``. The ``increment_value`` field validator adds the
    validation context (``Deps.increment``), so the final output must be 10 in every output mode.
    """

    def mock_llm(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
        # Answer through whichever channel the agent actually exposes for this request, rather
        # than guessing from the spec's Python type. If output tools are registered (tool mode;
        # a bare `Value` may also resolve to this depending on the default structured-output mode),
        # call the first one by its reported name. Otherwise (native/prompted output), reply with JSON text.
        if info.output_tools:
            return ModelResponse(parts=[ToolCallPart(tool_name=info.output_tools[0].name, args={'x': 0})])
        return ModelResponse(parts=[TextPart(content=Value(x=0).model_dump_json())])

    agent = Agent(
        FunctionModel(mock_llm),
        output_type=output_type,
        deps_type=Deps,
        validation_context=lambda ctx: ctx.deps.increment,
    )

    result = agent.run_sync('', deps=Deps(increment=10))
    assert result.output.x == snapshot(10)


def test_agent_tool_call_with_validation_context():
    """Check that a function tool's argument is validated with the run's validation context."""

    agent = Agent('test', deps_type=Deps, validation_context=lambda ctx: ctx.deps.increment)

    @agent.tool
    def get_value(ctx: RunContext[Deps], v: Value) -> int:
        # The 'test' model calls this tool with Value(x=0). By the time it arrives here, the
        # `increment_value` field validator should already have added the validation context.
        expected = ctx.deps.increment
        assert v.x == expected
        return v.x

    run_deps = Deps(increment=10)
    result = agent.run_sync('', deps=run_deps)
    assert result.output == snapshot('{"get_value":10}')


def test_agent_output_function_with_validation_context():
    """Check that an output function's argument is validated with the run's validation context."""

    def get_value(v: Value) -> int:
        # `v` has already been through pydantic validation, including the context-driven increment.
        return v.x

    run_deps = Deps(increment=10)
    agent = Agent(
        'test',
        output_type=get_value,
        deps_type=Deps,
        validation_context=lambda ctx: ctx.deps.increment,
    )

    output = agent.run_sync('', deps=run_deps).output
    assert output == snapshot(10)


def test_agent_output_validator_with_validation_context():
    """Check that output reaching an output validator was validated with the run's validation context."""

    agent = Agent(
        'test',
        output_type=Value,
        deps_type=Deps,
        validation_context=lambda ctx: ctx.deps.increment,
    )

    @agent.output_validator
    def identity(ctx: RunContext[Deps], v: Value) -> Value:
        # Pass-through validator. The value it receives is expected to already reflect the context.
        return v

    output = agent.run_sync('', deps=Deps(increment=10)).output
    assert output.x == snapshot(10)


def test_agent_output_validator_with_intermediary_deps_change_and_validation_context():
    """Check that the validation context is recomputed after run dependencies are mutated mid-run."""

    agent = Agent(
        'test',
        output_type=Value,
        deps_type=Deps,
        validation_context=lambda ctx: ctx.deps.increment,
    )

    @agent.tool
    def bump_increment(ctx: RunContext[Deps]):
        # Before the bump, the context was computed from the deps the run started with.
        assert ctx.validation_context == snapshot(10)
        ctx.deps.increment = ctx.deps.increment + 5

    @agent.output_validator
    def identity(ctx: RunContext[Deps], v: Value) -> Value:
        # After the tool call, the context must reflect the mutated deps.
        assert ctx.validation_context == snapshot(15)
        return v

    run_deps = Deps(increment=10)
    result = agent.run_sync('', deps=run_deps)
    assert result.output.x == snapshot(15)
